# Copyright 2021 The Kubeflow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Anomaly detection component using TensorFlow Probability."""
import kfp
from kfp.v2.dsl import Dataset
from kfp.v2.dsl import Input
from kfp.v2.dsl import Output


def tfp_anomaly_detection(input_dataset: Input[Dataset],
                          output_dataset: Output[Dataset],
                          time_col: str = 'timestamp',
                          feature_col: str = 'value',
                          num_samples: int = 50,
                          max_num_steps: int = 1000,
                          jit_compile: bool = True,
                          seed: int = None,
                          anomaly_threshold: float = 0.01) -> None:
  """Uses TFP STS to regularize a time series, fit a model, and predict anomalies.

  Pipeline: load a CSV time series, regularize it to a fixed frequency, fit a
  default structural time series (STS) model with variational inference, then
  flag points whose one-step-ahead predictive p-value falls below
  `anomaly_threshold`. Results are written as a CSV to `output_dataset.path`.

  Args:
    input_dataset: Input with given gcs path to file.
    output_dataset: Output with autogenerated gcs path to output file.
    time_col: Name of csv column with timestamps.
    feature_col: Name of csv column with feature values.
    num_samples: Number of data points to sample from the posterior
      distribution.
    max_num_steps: Number of steps to run the optimizer.
    jit_compile: If True, compiles the loss function and gradient update using
      XLA.
    seed: The random seed to use for sampling. May be None, in which case
      sampling is not seeded.
    anomaly_threshold: Any data point with a pvalue lower than this threshold is
      labelled anomalous.
  """

  # Imports are inside the function body so the component is self-contained
  # when serialized by KFP (see generate_component_file below).
  import collections
  import typing
  import itertools
  import numpy as np
  import pandas as pd
  import tensorflow.compat.v2 as tf
  import tensorflow_probability as tfp
  from tensorflow_probability.python.sts import default_model
  from tensorflow_probability.python.sts.internal import util as sts_util
  from tensorflow_probability.python.sts import one_step_predictive

  # Container for everything detect_anomalies computes; all list fields are
  # index-aligned with `time` (and `all_times`).
  PredictionOutput = collections.namedtuple('PredictionOutput', [
      'time', 'all_times', 'observed_series', 'mean', 'lower_limit',
      'upper_limit', 'anomalies', 'pvalues'
  ])
  logger = tf.get_logger()

  # TODO: Implement batch processing so that long time series are processed in chunks
  def load_data(path: str) -> pd.DataFrame:
    """Loads pandas dataframe from csv.

    Args:
      path: Path to the csv file.

    Returns:
      A standardized time series dataframe with a DatetimeIndex (parsed from
      `time_col`) and a single float64 'value' column (from `feature_col`).
    """
    df = pd.read_csv(path)
    time_index = pd.to_datetime(df[time_col])
    values = df[feature_col].astype('float64').tolist()
    standardized_df = pd.DataFrame(data={'value': values}, index=time_index)
    logger.info(
        'Input dataset has {0} rows. Note that if you run out of memory you should increase set_memory_limit in your pipeline.'
        .format(len(standardized_df)))
    return standardized_df

  def format_predictions(predictions: PredictionOutput) -> pd.DataFrame:
    """Saves predictions in a standardized csv format.

    Args:
      predictions: Anomaly detection output with fields all_times,
        observed_series, anomalies, pvalues, lower_limit, mean, upper_limit.

    Returns:
      predictions_df: A formatted pandas DataFrame compatible with scoring on
      the Numenta Anomaly Benchmark.
    """
    # `predictions.anomalies` holds the *indices* of anomalous points; turn it
    # into a per-row boolean label. Anomaly score is the complement of the
    # two-sided p-value, so higher score == more anomalous.
    anomalies = set(predictions.anomalies)
    anomaly_scores = [1 - pvalue for pvalue in predictions.pvalues]
    labels = [idx in anomalies for idx in range(len(anomaly_scores))]

    predictions_df = pd.DataFrame(
        data={
            'timestamp': predictions.all_times,
            'value': predictions.observed_series,
            'anomaly_score': anomaly_scores,
            'pvalue': predictions.pvalues,
            'label': labels,
            'lower_limit': predictions.lower_limit,
            'mean': predictions.mean,
            'upper_limit': predictions.upper_limit,
        })
    return predictions_df

  def detect_anomalies(regularized_series: pd.Series,
                       model: tfp.sts.StructuralTimeSeries,
                       posterior_samples: typing.OrderedDict[str, tf.Tensor],
                       anomaly_threshold: float) -> PredictionOutput:
    """Given a model, posterior, and anomaly_threshold, identifies anomalous time points in a time series.

    Args:
      regularized_series: A time series with a regular frequency.
      model: A fitted model that approximates the time series.
      posterior_samples: Posterior samples of model parameters.
      anomaly_threshold: The anomaly threshold passed as a parameter in the
        outer function.

    Returns:
      time: Timestamps where each index aligns with the indices of the other
      outputs.
      all_times: Same as time for anomaly detection but contains more timestamps
      for forecasting.
      anomalies: Indices of timestamps with anomalies.
      lower_limit: Lowest forecast value computed when fitting the model.
      mean: Mean of the distribution computed when fitting the model.
      upper_limit: Highest forecast value computed when fitting the model.
      observed_series: The time series that was originally input.
      pvalues: For each point, the tail-area probability of the data point.
    """
    # Canonicalize into a [..., T, 1] tensor plus an optional missing-value
    # mask (mask is None when the series has no gaps).
    [observed_time_series, mask
    ] = sts_util.canonicalize_observed_time_series_with_mask(regularized_series)

    anomaly_threshold = tf.convert_to_tensor(
        anomaly_threshold, dtype=observed_time_series.dtype)

    # The one-step predictive distribution covers the final `T - 1` timesteps.
    predictive_dist = one_step_predictive(
        model,
        regularized_series[:-1],
        posterior_samples,
        timesteps_are_event_shape=False)
    # Drop the first timestep (it has no one-step-ahead prediction) and the
    # trailing feature dimension.
    observed_series = observed_time_series[..., 1:, 0]
    times = regularized_series.index[1:]

    # Compute the tail probabilities (pvalues) of the observed series.
    # This is a two-sided p-value: twice the smaller tail area.
    prob_lower = predictive_dist.cdf(observed_series)
    tail_probabilities = 2 * tf.minimum(prob_lower, 1 - prob_lower)

    # Since quantiles of a mixture distribution are not analytically available,
    # use scalar root search to compute the upper and lower bounds.
    predictive_mean = predictive_dist.mean()
    predictive_stddev = predictive_dist.stddev()
    target_log_cdfs = tf.stack(
        [
            tf.math.log(anomaly_threshold / 2.) * tf.ones_like(predictive_mean),  # Target log CDF at lower bound.
            tf.math.log1p(-anomaly_threshold / 2.) * tf.ones_like(predictive_mean)  # Target log CDF at upper bound.
        ],
        axis=0)
    # Bracket each root generously (10-100 stddevs around the mean); row 0 of
    # `limits` is the lower quantile, row 1 the upper.
    limits, _, _ = tfp.math.find_root_chandrupatla(
        objective_fn=lambda x: target_log_cdfs - predictive_dist.log_cdf(x),
        low=tf.stack([
            predictive_mean - 100 * predictive_stddev,
            predictive_mean - 10 * predictive_stddev
        ],
                     axis=0),
        high=tf.stack([
            predictive_mean + 10 * predictive_stddev,
            predictive_mean + 100 * predictive_stddev
        ],
                      axis=0))

    # Identify anomalies.
    anomalies = np.less(tail_probabilities, anomaly_threshold)
    if mask is not None:
      # Never flag imputed/missing points as anomalies.
      anomalies = np.logical_and(anomalies, ~mask)
    # Convert the boolean vector into a list of anomalous indices.
    observed_anomalies = list(
        itertools.compress(range(len(times)), list(anomalies)))

    # Convert tensors/indices to plain Python lists for CSV serialization.
    times = times.strftime('%Y-%m-%d %H:%M:%S').tolist()
    observed_series = observed_series.numpy().tolist()
    predictive_mean = predictive_mean.numpy().tolist()
    lower_limit = limits[0].numpy().tolist()
    upper_limit = limits[1].numpy().tolist()
    tail_probabilities = tail_probabilities.numpy().tolist()

    return PredictionOutput(times, times, observed_series, predictive_mean,
                            lower_limit, upper_limit, observed_anomalies,
                            tail_probabilities)

  # Load the raw series and resample it onto a regular time grid.
  data = load_data(input_dataset.path)
  regularized_series = tfp.sts.regularize_series(data)

  # Fit an automatically-constructed STS model with variational inference,
  # stopping early once successive gradients decorrelate (convergence).
  model = default_model.build_default_model(regularized_series)
  surrogate_posterior = tfp.sts.build_factored_surrogate_posterior(
      model, seed=seed)
  _ = tfp.vi.fit_surrogate_posterior(
      target_log_prob_fn=model.joint_log_prob(regularized_series),
      surrogate_posterior=surrogate_posterior,
      optimizer=tf.optimizers.Adam(0.1),
      num_steps=max_num_steps,
      jit_compile=jit_compile,
      convergence_criterion=(
          tfp.optimizer.convergence_criteria.SuccessiveGradientsAreUncorrelated(
              window_size=20, min_num_steps=50)))

  posterior_samples = surrogate_posterior.sample(num_samples, seed=seed)

  # Score every point against the one-step predictive distribution and write
  # the benchmark-formatted results to the component's output path.
  predictions = detect_anomalies(regularized_series, model, posterior_samples,
                                 anomaly_threshold)
  predictions_df = format_predictions(predictions)
  predictions_df.to_csv(output_dataset.path)


def generate_component_file():
  """Compiles `tfp_anomaly_detection` into a KFP v2 component YAML file.

  The listed packages are installed in the component's container at runtime;
  the compiled spec is written to 'component.yaml'.
  """
  dependencies = [
      'numpy',
      'pandas',
      'tf-nightly',
      'git+https://github.com/tensorflow/probability.git',
  ]
  kfp.components.create_component_from_func_v2(
      tfp_anomaly_detection,
      packages_to_install=dependencies,
      output_component_file='component.yaml')
